# ------------------------------------------------------------------------------
# Check balance of mean predicted risk across triage shifts by shift test rate
# Author: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 06_risk_balance_scatter.sh
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(testit) # assert()
library(glue) # glue strings
library(ggplot2)
library(ggthemes) # colorblind
library(ggExtra) # ggMarginal()
library(broom) # tidy
library(stargazer) # Latex tables
library(OneR) # bin()
library(lfe) # felm()

# Scratch directory for this step: the input RDS is read from it and the
# output plot is written under its "plots" subdirectory.
temp <- here::here("code", "05_natural_experiment", "temp")
# NOTE(review): not referenced in the rest of this script as shown.
overnight_lab <- ""

# Project helper modules loaded via the `modules` package: `a` (aesthetics)
# and `u` (utilities; NOTE(review): `u` is not referenced below as shown).
a <- modules::use(here::here("lib", "aesthetics.R"))
u <- modules::use(here::here("lib", "util.R"))
# Presumably registers the bundled Optima font for plotting — confirm in
# lib/aesthetics.R.
a$get_font("Optima", here("lib", "optima.ttf"))

# Load Data --------------------------------------------------------------------
message("Loading data...")
# NOTE(review): `paths` is not referenced elsewhere in this script as shown.
paths <- read_yaml(here::here("lib", "filepaths.yml"))

# Encounter-level shift testing data, restricted to the train and test splits
# (rows whose split is NA are dropped, as with an `==` comparison).
shift_test_rates <- file.path(temp, "shift_test_rates.rds") %>%
  readRDS() %>%
  filter(split %in% c("train", "test"))

# First encounters only: for each ptid keep the ed_enc_id of the row with the
# smallest t0 (ties keep the earlier row, since order() is stable), then drop
# every row whose encounter was not selected.
keep_encs <- shift_test_rates %>%
  ungroup() %>%
  slice(order(t0)) %>%
  distinct(ptid, .keep_all = TRUE) %>%
  pull(ed_enc_id)

shift_test_rates <- shift_test_rates %>%
  filter(ed_enc_id %in% keep_encs)

# Y-hat-bar vs. shift test-rate percentile scatter -----------------------------
message("Getting shift-level yhat-bar and testing RE...")
# Collapse to one row per shift_12 value:
#   shift_yhat_bar    - mean ensemble predicted risk over the shift's rows
#   shift_12_trate    - the shift's test rate, taken from its first row
#                       (presumably constant within a shift — confirm)
#   t_hat_percentile  - percentile (1-100) of shift_12_trate across shifts
# (A leave-out rate, shift_12_trate_leaveout, was previously summarized here.)
shift_level_df <- shift_test_rates %>%
  group_by(shift_12) %>%
  summarise(
    shift_yhat_bar = mean(p__ensemble__stent_or_cabg_010_day__tested),
    shift_12_trate = dplyr::first(shift_12_trate)
  ) %>%
  ungroup() %>%
  mutate(t_hat_percentile = ntile(shift_12_trate, 100))

message("Plotting ", nrow(shift_level_df), " shifts...")
# Set to TRUE to produce plots with titles.
incl_titles <- FALSE

# Percentile ticks on the x axis, each labelled with the underlying shift test
# rate at that percentile, e.g. "0.123 (50)". quantile() is vectorized over
# `probs`, so the labels are built in one call rather than grown in a loop.
x_axis_ticks <- c(0, 25, 50, 75, 100)
tick_rates <- quantile(
  shift_level_df$shift_12_trate,
  probs = x_axis_ticks / 100,
  na.rm = TRUE
)
x_axis_labs <- paste0(
  round(unname(tick_rates), digits = 3), " (", x_axis_ticks, ")"
)

# Alternatives previously tried: x = shift_12_trate on a raw-rate axis, and a
# geom_smooth(method = "lm", color = "#475C7A") fit line.
gg <- ggplot(
  data = shift_level_df,
  aes(x = t_hat_percentile, y = shift_yhat_bar)
) +
  geom_point() +
  labs(
    x = "Triage Shift Test Rate (Percentile)",
    y = "Triage Shift Mean Predicted Risk"
  ) +
  # Breaks are set explicitly so each label sits on the tick it describes.
  # Limits include 0 because ntile() yields percentiles 1..100 only, and a
  # break outside the limits would be dropped.
  scale_x_continuous(
    breaks = x_axis_ticks,
    labels = x_axis_labs,
    limits = c(0, 100)
  ) +
  theme_bw()

if (incl_titles) {
  # Uses only objects defined in this script (the previous version referenced
  # undefined `re_var` and `test_cohort` and errored when titles were enabled).
  gg <- gg + labs(
    title = "Shift-level Mean Predicted Risk by Shift Test Rate",
    subtitle = "First encounter per patient; train and test splits",
    caption = glue("N = {nrow(shift_level_df)} shifts")
  )
}

# Add a marginal histogram of shift-level mean predicted risk on the y axis.
gg_marginal <- ggMarginal(
  gg, type = "histogram", fill = "#475C7A", margins = "y"
)

# NOTE(review): the file name says "leaveout_trate", but the x axis is the
# percentile of shift_12_trate, not a leave-out rate. The name is kept as-is
# so downstream references to this file still resolve.
fp <- file.path(temp, "plots", "scatter__yhatbar__by__leaveout_trate.png")
message("Saving yhat-bar scatter to ", fp)

# ggsave() does not create missing output directories by default.
dir.create(dirname(fp), recursive = TRUE, showWarnings = FALSE)
ggsave(plot = gg_marginal, filename = fp, width = 10, height = 7)

# Remove the Rplots.pdf that a non-interactive session may autogenerate when
# plots are built. invisible() keeps Rscript from printing the TRUE/FALSE
# result of file.remove().
fn <- "Rplots.pdf"
if (file.exists(fn)) {
  invisible(file.remove(fn))
}

message("Done.")
